// Copyright (c) Sean Lawlor
//
// This source code is licensed under both the MIT license found in the
// LICENSE-MIT file in the root directory of this source tree.

//! A clone of the [super::auth_handshake] test, but with encrypted (TLS) communications.
//!
//! The encryption certificates used are the same as those in [rustls]'s examples: <https://github.com/rustls/rustls>

use std::convert::TryFrom;
use std::fs::File;
use std::io::{self, BufReader};
use std::path::{Path, PathBuf};
use std::sync::Arc;

use clap::Args;
use ractor::concurrency::{sleep, Duration, Instant};
use ractor::Actor;
use rustls_pemfile::{certs, rsa_private_keys};
use tokio_rustls::rustls::{Certificate, OwnedTrustAnchor, PrivateKey};
use tokio_rustls::{TlsAcceptor, TlsConnector};

/// Maximum time, in milliseconds, allowed between first seeing a session and it reporting
/// that it is authenticated before the test fails.
const AUTH_TIME_ALLOWANCE_MS: u128 = 1_500;

/// Configuration
///
/// `client_host` and `client_port` must both be given for this node to open an outgoing
/// connection; otherwise it only listens on `server_port`.
#[derive(Args, Debug, Clone)]
pub struct EncryptionConfig {
    /// Server port
    server_port: u16,
    /// If specified, the port of the remote node to connect to
    client_port: Option<u16>,
    /// If specified, the host of the remote node to connect to
    client_host: Option<String>,
}

/// Load every PEM-encoded certificate found in the file at `path`.
///
/// # Errors
///
/// Returns the I/O error if the file cannot be opened. Returns an
/// [`io::ErrorKind::InvalidInput`] error with the message `"invalid cert"` if reading
/// or parsing its PEM contents fails.
fn load_certs(path: &Path) -> io::Result<Vec<Certificate>> {
    let mut reader = BufReader::new(File::open(path)?);
    certs(&mut reader)
        .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert"))
        // Each entry is the DER bytes of one certificate; wrap them in rustls' newtype.
        .map(|ders| ders.into_iter().map(Certificate).collect())
}

/// Load the RSA private keys found in the PEM file at `path`.
///
/// Keys are extracted with `rustls_pemfile::rsa_private_keys`. NOTE(review): per that
/// crate's docs, this only matches PKCS#1 `RSA PRIVATE KEY` sections. The returned vector
/// may therefore be empty, so callers must not assume at least one key is present.
///
/// # Errors
///
/// Returns the I/O error if the file cannot be opened. Returns an
/// [`io::ErrorKind::InvalidInput`] error with the message `"invalid key"` if reading
/// or parsing its PEM contents fails.
fn load_keys(path: &Path) -> io::Result<Vec<PrivateKey>> {
    let mut reader = BufReader::new(File::open(path)?);
    rsa_private_keys(&mut reader)
        .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid key"))
        // Each entry is the DER bytes of one key; wrap them in rustls' newtype.
        .map(|ders| ders.into_iter().map(PrivateKey).collect())
}

pub async fn test(config: EncryptionConfig) -> i32 {
    let cookie = "cookie".to_string();
    let hostname = "localhost".to_string();

    // ================== Server TLS Configuration ================== //
    // Example `rustls` command: cargo run --bin tlsserver-mio -- --certs test-ca/rsa/end.fullchain --key test-ca/rsa/end.rsa -p 8443 echo
    //
    // combined with source code: https://github.com/tokio-rs/tls/blob/357bc562483dcf04c1f8d08bd1a831b144bf7d4c/tokio-rustls/examples/server/src/main.rs
    let cert_path = PathBuf::from("test-ca/rsa/end.fullchain");
    let key_path = PathBuf::from("test-ca/rsa/end.rsa");
    let certs = load_certs(&cert_path).expect("Failed to load encryption certificates");
    let mut keys = load_keys(&key_path).expect("Failed to load encryption keys");

    let server_config = tokio_rustls::rustls::ServerConfig::builder()
        .with_safe_defaults()
        .with_no_client_auth()
        .with_single_cert(certs, keys.remove(0))
        .expect("Failed to build server configuration");
    let acceptor = TlsAcceptor::from(Arc::new(server_config));

    // ================== Client TLS Configuration ================== //

    let ca_path = PathBuf::from("test-ca/rsa/ca.cert");
    let mut ca_pem = BufReader::new(File::open(ca_path).expect("Failed to load CA certificate"));
    let ca_certs = rustls_pemfile::certs(&mut ca_pem).expect("Failed to parse CA certificate");

    let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
    let trust_anchors = ca_certs.iter().map(|cert| {
        let ta =
            webpki::TrustAnchor::try_from_cert_der(&cert[..]).expect("Failed to build TrustAnchor");
        tracing::warn!(
            "CA Cert SUB={}",
            String::from_utf8(ta.subject.to_vec()).unwrap_or("n/a".to_string())
        );
        OwnedTrustAnchor::from_subject_spki_name_constraints(
            ta.subject,
            ta.spki,
            ta.name_constraints,
        )
    });
    root_cert_store.add_trust_anchors(trust_anchors);
    let client_config = tokio_rustls::rustls::ClientConfig::builder()
        .with_safe_defaults()
        .with_root_certificates(root_cert_store)
        .with_no_client_auth();
    let connector = TlsConnector::from(Arc::new(client_config));

    // NOTE: It's `testserver.com` because that's what's generated by the rustls team. Eventually we should re-generate
    // our own certs but this is just a temporary hack for the test
    let domain = tokio_rustls::rustls::ServerName::try_from("testserver.com")
        .expect("Invalid DNS name `node-a`");

    // ================== Server Creation ================== //

    let server = ractor_cluster::NodeServer::new(
        config.server_port,
        cookie,
        super::random_name(),
        hostname.clone(),
        Some(ractor_cluster::IncomingEncryptionMode::Tls(acceptor)),
        None,
    );

    tracing::info!("Starting NodeServer on port {}", config.server_port);

    let (actor, handle) = Actor::spawn(None, server, ())
        .await
        .expect("Failed to start NodeServer");

    if let (Some(client_host), Some(client_port)) = (config.client_host, config.client_port) {
        tracing::info!("Connecting to remote NodeServer at {client_host}:{client_port}");
        if let Err(error) = ractor_cluster::client_connect_enc(
            &actor,
            format!("{client_host}:{client_port}"),
            connector,
            domain,
        )
        .await
        {
            tracing::error!("Failed to connect with error {error}");
            return -3;
        } else {
            tracing::info!("Client connected NodeServer b to NodeServer a");
        }
    }

    let mut err_code = -1;
    tracing::info!("Waiting for NodeSession status updates");

    let mut rpc_reply = ractor::call_t!(actor, ractor_cluster::NodeServerMessage::GetSessions, 200);
    let mut tic = None;

    while rpc_reply.is_ok() {
        if let Some(timestamp) = tic {
            let time: Duration = Instant::now() - timestamp;
            if time.as_millis() > AUTH_TIME_ALLOWANCE_MS {
                err_code = -2;
                tracing::error!(
                    "The authentcation time has been going on for over > {}ms. Failing the test",
                    time.as_millis()
                );
                break;
            }
        }

        if let Some(item) = rpc_reply
            .unwrap()
            .into_values()
            .collect::<Vec<_>>()
            .first()
            .cloned()
        {
            // we got an actor, track how long it took to auth, maxing out at 500ms
            if tic.is_none() {
                tic = Some(Instant::now());
            }

            let is_authenticated = ractor::call_t!(
                item.actor,
                ractor_cluster::NodeSessionMessage::GetAuthenticationState,
                200
            );
            match is_authenticated {
                Err(err) => {
                    tracing::warn!("NodeSession returned error on rpc query {err}");
                    break;
                }
                Ok(false) => {
                    // Still waiting
                }
                Ok(true) => {
                    err_code = 0;
                    tracing::info!("Authentication succeeded. Exiting test");
                    break;
                }
            }
        }
        // try again
        rpc_reply = ractor::call_t!(actor, ractor_cluster::NodeServerMessage::GetSessions, 200);
    }

    tracing::info!("Terminating test - code {err_code}");

    sleep(Duration::from_millis(250)).await;

    // cleanup
    actor.stop(None);
    handle.await.unwrap();

    err_code
}
